import math
from typing import Dict, Optional, List, Tuple

import numpy as np
import pandas as pd
import pingouin as pg
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.gridspec import GridSpec

from config import Config, FiguresConfig, DataConfig
from dataProcessing import DataProcessing
from utils import Utils


class OtherVisualizationCreator:
    # region 01 Pure Trial Vis
    @staticmethod
    def create_vis_for_trial_answers(sp: bool = False, hp: bool = False, kde: bool = False):
        """Render the "01 trialAnswer" figures for every enabled plot type.

        For each enabled plot type, one figure is written per area/content
        filter listed below. Every figure is a 3 x 5 grid: one row per part
        (1-3) and one column per ``fixedValueIndex`` (0-4).

        :param sp: render the scatterplot variant
        :param hp: render the histplot variant
        :param kde: render the kde variant
        """
        # part -> (key, x column, y column)
        part_columns = {
            1: ("size", "tilt", "distance"),
            2: ("tilt", "distance", "size"),
            3: ("distance", "size", "tilt"),
        }
        data_per_parts = {
            part: {
                "data": DataProcessing.get_trials_answers(part=part),
                "key": part_key,
                "x": x_column,
                "y": y_column,
            }
            for part, (part_key, x_column, y_column) in part_columns.items()
        }

        # plot type -> (enabled flag, helper drawing one grid cell)
        drawers = {
            "scatterplot": (sp, OtherVisualizationCreator._trial_answers_scatterplot),
            "histplot": (hp, OtherVisualizationCreator._trial_answers_histplot),
            "kde": (kde, OtherVisualizationCreator._trial_answers_kde),
        }

        # (column, value) filters; a value of None keeps all rows
        filters = [
            ("area", None),
            ("area", "Ceiling"),
            ("area", "Floor"),
            ("content", None),
            ("content", "Low"),
            ("content", "Medium"),
        ]

        fig_size = (15, 9)

        for vis, (enabled, draw_cell) in drawers.items():
            if not enabled:
                continue

            for other_key, other_key_value in filters:
                margin = FiguresConfig.calc_margin_complete(fig_size=fig_size, absolute_margin=0.75)
                spacing = FiguresConfig.calc_spacing_complete(fig_size=fig_size, absolute_spacing=0.35)

                fig = plt.figure(figsize=fig_size, dpi=FiguresConfig.Dpi)
                gs = GridSpec(
                    nrows=3,
                    ncols=5,
                    figure=fig,
                    left=margin["left"],
                    top=margin["top"],
                    right=margin["right"],
                    bottom=margin["bottom"],
                    wspace=fig_size[0] * spacing["vertical"],
                    hspace=fig_size[1] * spacing["horizontal"],
                )

                for part in (1, 2, 3):
                    part_data = data_per_parts[part]
                    answers = part_data["data"]

                    for column in range(5):
                        ax = fig.add_subplot(gs[part - 1, column])

                        selected = answers["fixedValueIndex"] == column
                        if other_key_value is not None:
                            selected = selected & (answers[other_key] == other_key_value)

                        draw_cell(ax, answers[selected], part_data, other_key)

                fig.tight_layout()
                for format_str in FiguresConfig.VisFormats:
                    fig.savefig(f"{Config.FigureOutputPath}/01 trialAnswer {vis} {other_key}={other_key_value}.{format_str}")
                plt.close(fig)

    @staticmethod
    def _trial_answers_scatterplot(ax: plt.Axes, data: pd.DataFrame, part_data: Dict, other_key: str):
        """Draw one grid cell as a scatter plot coloured by ``other_key``.

        :param ax: axes to draw into
        :param data: rows to plot
        :param part_data: part description; its "x" and "y" entries name the plotted columns
        :param other_key: hue column; also selects the palette from ``FiguresConfig.Palettes``
        """
        colours = FiguresConfig.Palettes[other_key]

        with sns.axes_style(FiguresConfig.SnsStyle):
            sns.scatterplot(
                data=data,
                x=part_data["x"],
                y=part_data["y"],
                hue=other_key,
                hue_order=colours.keys(),
                palette=colours,
                s=50,
                alpha=.2,
                ax=ax,
            )

        # Both axes get the same treatment: ticks on the fixed values, limits
        # padded by RangeLimitFactor times the fixed-value span.
        for column, set_limits, set_ticks in (
            (part_data["x"], ax.set_xlim, ax.set_xticks),
            (part_data["y"], ax.set_ylim, ax.set_yticks),
        ):
            ticks = DataConfig.FixedValues[column]
            pad = (ticks[-1] - ticks[0]) * FiguresConfig.RangeLimitFactor
            set_limits(ticks[0] - pad, ticks[-1] + pad)
            set_ticks(ticks)

        legend = ax.get_legend()
        if legend is not None:
            legend.remove()

    @staticmethod
    def _trial_answers_kde(ax: plt.Axes, data: pd.DataFrame, part_data: Dict, other_key: str):
        """Draw one grid cell as a filled two-dimensional KDE.

        :param ax: axes to draw into
        :param data: rows to plot
        :param part_data: part description; its "x" and "y" entries name the plotted columns
        :param other_key: not used by this plot type
        """
        x_column = part_data["x"]
        y_column = part_data["y"]

        with sns.axes_style(FiguresConfig.SnsStyle):
            sns.kdeplot(
                data=data,
                x=x_column,
                y=y_column,
                fill=True,
                cmap=sns.cubehelix_palette(start=0, light=1, as_cmap=True),
                ax=ax,
            )

        x_ticks = DataConfig.FixedValues[x_column]
        y_ticks = DataConfig.FixedValues[y_column]
        # Padding is RangeLimitFactor times the span of the fixed values.
        x_pad = (x_ticks[-1] - x_ticks[0]) * FiguresConfig.RangeLimitFactor
        y_pad = (y_ticks[-1] - y_ticks[0]) * FiguresConfig.RangeLimitFactor

        ax.set_xlim(left=x_ticks[0] - x_pad, right=x_ticks[-1] + x_pad)
        ax.set_ylim(bottom=y_ticks[0] - y_pad, top=y_ticks[-1] + y_pad)
        ax.set_xticks(x_ticks)
        ax.set_yticks(y_ticks)

        legend = ax.get_legend()
        if legend is not None:
            legend.remove()

    @staticmethod
    def _trial_answers_histplot(ax: plt.Axes, data: pd.DataFrame, part_data: Dict, other_key: str):
        """Draw one grid cell as a two-dimensional histogram coloured by ``other_key``.

        The histogram uses 16 bins, with the bin range running from the first
        to the last fixed value of each plotted column.

        :param ax: axes to draw into
        :param data: rows to plot
        :param part_data: part description; its "x" and "y" entries name the plotted columns
        :param other_key: hue column; also selects the palette from ``FiguresConfig.Palettes``
        """
        def padded_limits(values):
            # Extend the fixed-value span by RangeLimitFactor on both sides.
            pad = (values[-1] - values[0]) * FiguresConfig.RangeLimitFactor
            return values[0] - pad, values[-1] + pad

        x_values = DataConfig.FixedValues[part_data["x"]]
        y_values = DataConfig.FixedValues[part_data["y"]]
        colours = FiguresConfig.Palettes[other_key]

        with sns.axes_style(FiguresConfig.SnsStyle):
            sns.histplot(
                data=data,
                x=part_data["x"],
                y=part_data["y"],
                hue=other_key,
                hue_order=colours.keys(),
                palette=colours,
                bins=16,
                binrange=(
                    (x_values[0], x_values[-1]),
                    (y_values[0], y_values[-1]),
                ),
                alpha=.5,
                ax=ax,
            )

        x_low, x_high = padded_limits(x_values)
        y_low, y_high = padded_limits(y_values)
        ax.set_xlim(left=x_low, right=x_high)
        ax.set_ylim(bottom=y_low, top=y_high)
        ax.set_xticks(x_values)
        ax.set_yticks(y_values)

        legend = ax.get_legend()
        if legend is not None:
            legend.remove()
    # endregion

    # region 02 Angular Size Values Views 1
    @staticmethod
    def create_vis_for_other_params(kde: bool = False, viewing_angle: bool = False, angular_size: bool = False, duration: bool = False):
        """Render the "02" figures: one area x content facet grid per enabled value.

        A figure is written for every combination of an enabled plot type and
        an enabled value column, in every format of ``FiguresConfig.VisFormats``.

        :param kde: render the kde plot type (the only plot type available here)
        :param viewing_angle: plot the "viewingAngle" column
        :param angular_size: plot the "angularSize" column
        :param duration: plot the "duration" column
        """
        vis_to_create = {
            "kde": kde
        }
        key_to_use = {
            "angularSize": angular_size,
            "viewingAngle": viewing_angle,
            "duration": duration
        }

        data = DataProcessing.get_trials_answers()
        for vis, render_vis in vis_to_create.items():
            if not render_vis:
                continue

            for key, render_key in key_to_use.items():
                if not render_key:
                    continue

                with sns.axes_style(FiguresConfig.SnsStyle):
                    grid = sns.FacetGrid(
                        data=data,
                        col="content",
                        row="area",
                    )
                    grid.fig.set_size_inches(8, 8)
                    grid.fig.set_dpi(FiguresConfig.Dpi)

                    if vis == "kde":
                        OtherVisualizationCreator._angular_values_kde(grid, key)

                # The FacetGrid constructor always returns a grid, so no None
                # check is needed before saving.
                grid.fig.tight_layout()
                for format_str in FiguresConfig.VisFormats:
                    grid.fig.savefig(f"{Config.FigureOutputPath}/02 {key} {vis}.{format_str}")
                plt.close(grid.fig)

    @staticmethod
    def _angular_values_kde(grid: sns.FacetGrid, key: str):
        """Map a filled one-dimensional KDE of ``key`` onto every facet of ``grid``.

        The y axis is labelled "density". If ``key`` has an entry in
        ``DataConfig.FixedValues``, the x ticks are set to those values and the
        x limits to their padded span.

        :param grid: facet grid to draw into
        :param key: column to plot
        """
        grid.map(
            sns.kdeplot,
            key,
            fill=True,
            cmap=sns.cubehelix_palette(start=0, light=1, as_cmap=True),
        )
        grid.set_axis_labels(y_var="density")

        if key not in DataConfig.FixedValues:
            return

        ticks = DataConfig.FixedValues[key]
        pad = (ticks[-1] - ticks[0]) * FiguresConfig.RangeLimitFactor
        grid.set(
            xlim=(ticks[0] - pad, ticks[-1] + pad),
            xticks=ticks,
        )
    # endregion

    # region 03 Angular Size to Viewing Angle
    @staticmethod
    def create_vis_angular_size_to_viewing_angle():
        """Render figure "03": a 2D KDE of angularSize against viewingAngle.

        The trial answers are faceted by content (columns) and area (rows).
        The figure is saved in every format of ``FiguresConfig.VisFormats``.
        """
        x_column = "angularSize"
        y_column = "viewingAngle"
        answers = DataProcessing.get_trials_answers()

        with sns.axes_style(FiguresConfig.SnsStyle):
            grid = sns.FacetGrid(data=answers, col="content", row="area")
            grid.fig.set_size_inches(8, 8)
            grid.fig.set_dpi(FiguresConfig.Dpi)

            grid.map(
                sns.kdeplot,
                x_column,
                y_column,
                fill=True,
                cmap=sns.cubehelix_palette(start=0, light=1, as_cmap=True),
            )

        x_ticks = DataConfig.FixedValues[x_column]
        y_ticks = DataConfig.FixedValues[y_column]
        # Limits extend the fixed-value span by RangeLimitFactor on both sides.
        x_pad = (x_ticks[-1] - x_ticks[0]) * FiguresConfig.RangeLimitFactor
        y_pad = (y_ticks[-1] - y_ticks[0]) * FiguresConfig.RangeLimitFactor
        grid.set(
            xlim=(x_ticks[0] - x_pad, x_ticks[-1] + x_pad),
            ylim=(y_ticks[0] - y_pad, y_ticks[-1] + y_pad),
            xticks=x_ticks,
            yticks=y_ticks,
        )

        figure = grid.fig
        figure.tight_layout()
        for format_str in FiguresConfig.VisFormats:
            figure.savefig(f"{Config.FigureOutputPath}/03 angularSize to viewingAngle.{format_str}")
        plt.close(figure)
    # endregion

    # region 04 Angular Size Values Views 2
    @staticmethod
    def create_vis_for_angular_size_to_parameters(kde: bool = False, viewing_angle: bool = False, angular_size: bool = False):
        """Render the "04" figures: an angular value against distance, tilt and size.

        One figure is written per enabled plot type and enabled angular value.
        Each figure is a 3 x 4 grid: one row per parameter (distance, tilt,
        size) and one column per area/content combination from
        ``DataConfig.KeyValues``.

        :param kde: render the kde plot type (the only plot type available here)
        :param viewing_angle: plot the "viewingAngle" column on the x axis
        :param angular_size: plot the "angularSize" column on the x axis
        """
        vis_to_create = {"kde": kde}
        key_to_use = {
            "angularSize": angular_size,
            "viewingAngle": viewing_angle,
        }
        enabled_keys = [name for name, wanted in key_to_use.items() if wanted]
        fig_size = (12, 9)

        data = DataProcessing.get_trials_answers()
        for vis, wanted in vis_to_create.items():
            if not wanted:
                continue

            for key in enabled_keys:
                margin = FiguresConfig.calc_margin_complete(fig_size=fig_size, absolute_margin=0.65)
                spacing = FiguresConfig.calc_spacing_complete(fig_size=fig_size, absolute_spacing=0.35)

                fig = plt.figure(figsize=fig_size, dpi=FiguresConfig.Dpi)
                gs = GridSpec(
                    nrows=3,
                    ncols=4,
                    figure=fig,
                    left=margin["left"],
                    top=margin["top"],
                    right=margin["right"],
                    bottom=margin["bottom"],
                    wspace=fig_size[0] * spacing["vertical"],
                    hspace=fig_size[1] * spacing["horizontal"],
                )

                for row, para_key in enumerate(("distance", "tilt", "size")):
                    for area_index, area in enumerate(DataConfig.KeyValues["area"]):
                        in_area = data["area"] == area
                        for content_index, content in enumerate(DataConfig.KeyValues["content"]):
                            subset = data[in_area & (data["content"] == content)]

                            # Column layout: two content columns per area.
                            ax = fig.add_subplot(gs[row, area_index * 2 + content_index])

                            if vis == "kde":
                                OtherVisualizationCreator._angular_values_to_parameter_kde(ax, subset, key, para_key)

                fig.tight_layout()
                for format_str in FiguresConfig.VisFormats:
                    fig.savefig(f"{Config.FigureOutputPath}/04 {key} for parameters {vis}.{format_str}")
                plt.close(fig)

    @staticmethod
    def _angular_values_to_parameter_kde(ax: plt.Axes, data: pd.DataFrame, key: str, other_key: str):
        """Draw a filled two-dimensional KDE of ``key`` (x) against ``other_key`` (y).

        Ticks are placed on the ``DataConfig.FixedValues`` of each column and
        the limits cover their padded span.

        :param ax: axes to draw into
        :param data: rows to plot
        :param key: column on the x axis
        :param other_key: column on the y axis
        """
        with sns.axes_style(FiguresConfig.SnsStyle):
            sns.kdeplot(
                data=data,
                x=key,
                y=other_key,
                fill=True,
                cmap=sns.cubehelix_palette(start=0, light=1, rot=-.4, as_cmap=True),
                ax=ax,
            )

        for column, set_limits, set_ticks in (
            (key, ax.set_xlim, ax.set_xticks),
            (other_key, ax.set_ylim, ax.set_yticks),
        ):
            ticks = DataConfig.FixedValues[column]
            pad = (ticks[-1] - ticks[0]) * FiguresConfig.RangeLimitFactor
            set_limits(ticks[0] - pad, ticks[-1] + pad)
            set_ticks(ticks)

        legend = ax.get_legend()
        if legend is not None:
            legend.remove()
    # endregion

    # region 05 Angular Size Values Views 3 (Participants)
    @staticmethod
    def create_vis_for_other_params_to_participants(bar: bool = False, viewing_angle: bool = False, angular_size: bool = False, duration: bool = False, aggregation: str = "mean"):
        """Render the "05" figures: one aggregated value per participant and condition.

        For every enabled value column, the trial answers are grouped by
        participant, area, content and conditionKey, reduced with
        ``Utils.aggregate`` and drawn into a single-axes figure.

        :param bar: render the barchart plot type (the only plot type available here)
        :param viewing_angle: plot the "viewingAngle" column
        :param angular_size: plot the "angularSize" column
        :param duration: plot the "duration" column
        :param aggregation: passed to ``Utils.aggregate``; also part of the file name
        """
        vis_to_create = {"barchart": bar}
        key_to_use = {
            "angularSize": angular_size,
            "viewingAngle": viewing_angle,
            "duration": duration,
        }
        enabled_keys = [name for name, wanted in key_to_use.items() if wanted]
        group_columns = ["participant", "area", "content", "conditionKey"]
        fig_size = (15, 4)

        data = DataProcessing.get_trials_answers()
        for vis, wanted in vis_to_create.items():
            if not wanted:
                continue

            for key in enabled_keys:
                margin = FiguresConfig.calc_margin_complete(fig_size=fig_size, absolute_margin=0.65)
                spacing = FiguresConfig.calc_spacing_complete(fig_size=fig_size, absolute_spacing=0.35)

                fig = plt.figure(figsize=fig_size, dpi=FiguresConfig.Dpi)
                gs = GridSpec(
                    nrows=1,
                    ncols=1,
                    figure=fig,
                    left=margin["left"],
                    top=margin["top"],
                    right=margin["right"],
                    bottom=margin["bottom"],
                    wspace=fig_size[0] * spacing["vertical"],
                    hspace=fig_size[1] * spacing["horizontal"],
                )

                grouped = data[group_columns + [key]].groupby(group_columns, as_index=False)
                aggregated = Utils.aggregate(grouped, aggregation)

                if vis == "barchart":
                    OtherVisualizationCreator._angular_values_for_participants_bar(
                        fig.add_subplot(gs[0, 0]), aggregated, key
                    )

                fig.tight_layout()
                for format_str in FiguresConfig.VisFormats:
                    fig.savefig(f"{Config.FigureOutputPath}/05 {key} for parameters {vis}-{aggregation}.{format_str}")
                plt.close(fig)

    @staticmethod
    def _angular_values_for_participants_bar(ax: plt.Axes, data: pd.DataFrame, key: str):
        """Draw ``key`` per participant as bars grouped by conditionKey.

        If ``key`` has an entry in ``DataConfig.FixedValues``, the y ticks are
        set to those values and the y limits are fixed.

        :param ax: axes to draw into
        :param data: frame with "participant", "conditionKey" and ``key`` columns
        :param key: column shown on the y axis
        """
        condition_colours = FiguresConfig.Palettes["conditionKey"]

        with sns.axes_style(FiguresConfig.SnsStyle):
            sns.barplot(
                data=data,
                x=data["participant"],
                y=data[key],
                hue="conditionKey",
                hue_order=condition_colours.keys(),
                palette=condition_colours,
                ax=ax,
            )

        if key not in DataConfig.FixedValues:
            return

        ticks = DataConfig.FixedValues[key]
        # Only the upper limit is padded; the lower limit is the first fixed value.
        upper_pad = (ticks[-1] - ticks[0]) * FiguresConfig.RangeLimitFactor
        ax.set_ylim(bottom=ticks[0], top=ticks[-1] + upper_pad)
        ax.set_yticks(ticks)
    # endregion

    # region 06 Height to Angular Values 1
    @staticmethod
    def create_vis_for_height_to_parameters(sp: bool = False, swarm: bool = False, strip: bool = False, vp: bool = False):
        """Render the "06 height for parameters" figure for every enabled plot type.

        Each figure is a 5 x 4 grid: one row per plotted value (angularSize,
        viewingAngle, distance, tilt, size) and one column per area/content
        combination from ``DataConfig.KeyValues``. The questionnaire "Height"
        of each participant is attached to the trial answers as a "height"
        column, which the plot helpers put on the x axis.

        :param sp: render the scatterplot variant
        :param swarm: render the swarmplot variant
        :param strip: render the stripplot variant
        :param vp: render the violinplot variant
        :raises IndexError: if a participant in the trial answers has no
            questionnaire row (same exception type as the former row-wise lookup)
        """
        # plot type -> (enabled flag, helper drawing one grid cell)
        plotters = {
            "scatterplot": (sp, OtherVisualizationCreator._height_to_value_scatterplot),
            "swarmplot": (swarm, OtherVisualizationCreator._height_to_value_swarmplot),
            "stripplot": (strip, OtherVisualizationCreator._height_to_value_stripplot),
            "violinplot": (vp, OtherVisualizationCreator._height_to_value_violinplot),
        }

        q_data = DataProcessing.get_questionnaire_data()
        data = DataProcessing.get_trials_answers()
        if not any(enabled for enabled, _ in plotters.values()):
            return

        # Resolve every participant's height once. The previous version ran a
        # row-wise lookup on each filtered slice, which triggered
        # SettingWithCopyWarning and was repeated for every cell and plot type.
        # drop_duplicates keeps the first questionnaire row per participant,
        # matching the former ``.values[0]`` choice.
        height_by_participant = (
            q_data.drop_duplicates(subset="participant").set_index("participant")["Height"]
        )
        unknown = ~data["participant"].isin(height_by_participant.index)
        if unknown.any():
            raise IndexError(
                "No questionnaire entry for participant(s): "
                f"{data.loc[unknown, 'participant'].unique().tolist()}"
            )
        # assign() returns a new frame, so the frame handed out by
        # DataProcessing is not modified.
        data = data.assign(height=data["participant"].map(height_by_participant))

        # (grid column, rows of that area/content combination). The column
        # layout assumes two content values per area, as before.
        cells = [
            (ai * 2 + cti, data[(data["area"] == a) & (data["content"] == ct)])
            for ai, a in enumerate(DataConfig.KeyValues["area"])
            for cti, ct in enumerate(DataConfig.KeyValues["content"])
        ]

        fig_size = (24, 15)
        for vis, (enabled, draw_cell) in plotters.items():
            if not enabled:
                continue

            margin = FiguresConfig.calc_margin_complete(fig_size=fig_size, absolute_margin=0.65)
            spacing = FiguresConfig.calc_spacing_complete(fig_size=fig_size, absolute_spacing=0.25)

            fig = plt.figure(
                figsize=fig_size,
                dpi=FiguresConfig.Dpi
            )
            gs = GridSpec(
                nrows=5,
                ncols=4,
                figure=fig,
                left=margin["left"],
                top=margin["top"],
                right=margin["right"],
                bottom=margin["bottom"],
                wspace=fig_size[0] * spacing["vertical"],
                hspace=fig_size[1] * spacing["horizontal"],
            )

            for ki, para_key in enumerate(["angularSize", "viewingAngle", "distance", "tilt", "size"]):
                for column, cell_data in cells:
                    ax = fig.add_subplot(gs[ki, column])
                    draw_cell(ax, cell_data, para_key)

            fig.tight_layout()
            for format_str in FiguresConfig.VisFormats:
                fig.savefig(f"{Config.FigureOutputPath}/06 height for parameters {vis}.{format_str}")
            plt.close(fig)

    @staticmethod
    def _height_to_value_scatterplot(ax: plt.Axes, data: pd.DataFrame, other_key: str):
        """Draw ``other_key`` against "height" as a scatter plot.

        Both axes get their ticks from ``DataConfig.FixedValues`` and limits
        covering the padded fixed-value span.

        :param ax: axes to draw into
        :param data: rows to plot; must contain "height" and ``other_key``
        :param other_key: column on the y axis
        """
        def padded_limits(values):
            # Extend the fixed-value span by RangeLimitFactor on both sides.
            pad = (values[-1] - values[0]) * FiguresConfig.RangeLimitFactor
            return values[0] - pad, values[-1] + pad

        with sns.axes_style(FiguresConfig.SnsStyle):
            sns.scatterplot(data=data, x="height", y=other_key, ax=ax)

        height_ticks = DataConfig.FixedValues["height"]
        value_ticks = DataConfig.FixedValues[other_key]

        left, right = padded_limits(height_ticks)
        bottom, top = padded_limits(value_ticks)
        ax.set_xlim(left=left, right=right)
        ax.set_ylim(bottom=bottom, top=top)
        ax.set_xticks(height_ticks)
        ax.set_yticks(value_ticks)

        legend = ax.get_legend()
        if legend is not None:
            legend.remove()

    @staticmethod
    def _height_to_value_swarmplot(ax: plt.Axes, data: pd.DataFrame, other_key: str):
        """Draw ``other_key`` per "height" as a swarm plot with marker size 2.

        Only the y axis is configured: ticks on the fixed values of
        ``other_key`` and limits covering their padded span.

        :param ax: axes to draw into
        :param data: rows to plot; must contain "height" and ``other_key``
        :param other_key: column on the y axis
        """
        with sns.axes_style(FiguresConfig.SnsStyle):
            sns.swarmplot(data=data, x="height", y=other_key, s=2, ax=ax)

        value_ticks = DataConfig.FixedValues[other_key]
        lowest, highest = value_ticks[0], value_ticks[-1]
        pad = (highest - lowest) * FiguresConfig.RangeLimitFactor

        ax.set_ylim(bottom=lowest - pad, top=highest + pad)
        ax.set_yticks(value_ticks)

        legend = ax.get_legend()
        if legend is not None:
            legend.remove()

    @staticmethod
    def _height_to_value_stripplot(ax: plt.Axes, data: pd.DataFrame, other_key: str):
        """Draw ``other_key`` per "height" as a strip plot with alpha 0.25.

        Only the y axis is configured: ticks on the fixed values of
        ``other_key`` and limits covering their padded span.

        :param ax: axes to draw into
        :param data: rows to plot; must contain "height" and ``other_key``
        :param other_key: column on the y axis
        """
        with sns.axes_style(FiguresConfig.SnsStyle):
            sns.stripplot(
                data=data,
                x="height",
                y=other_key,
                alpha=.25,
                ax=ax,
            )

        y_ticks = DataConfig.FixedValues[other_key]
        y_pad = (y_ticks[-1] - y_ticks[0]) * FiguresConfig.RangeLimitFactor
        y_limits = (y_ticks[0] - y_pad, y_ticks[-1] + y_pad)

        ax.set_ylim(bottom=y_limits[0], top=y_limits[1])
        ax.set_yticks(y_ticks)

        existing_legend = ax.get_legend()
        if existing_legend is not None:
            existing_legend.remove()

    @staticmethod
    def _height_to_value_violinplot(ax: plt.Axes, data: pd.DataFrame, other_key: str):
        """Draw ``other_key`` per "height" as vertical violins with inner points.

        Only the y axis is configured: ticks on the fixed values of
        ``other_key`` and limits covering their padded span.

        :param ax: axes to draw into
        :param data: rows to plot; must contain "height" and ``other_key``
        :param other_key: column on the y axis
        """
        with sns.axes_style(FiguresConfig.SnsStyle):
            sns.violinplot(
                data=data,
                x="height",
                y=other_key,
                orient="v",
                inner="points",
                ax=ax,
            )

        ticks = DataConfig.FixedValues[other_key]
        pad = (ticks[-1] - ticks[0]) * FiguresConfig.RangeLimitFactor

        ax.set_ylim(bottom=ticks[0] - pad, top=ticks[-1] + pad)
        ax.set_yticks(ticks)

        legend = ax.get_legend()
        if legend is not None:
            legend.remove()
    # endregion

    # region 07 Height to Angular Values 2
    @staticmethod
    def create_vis_for_height_to_parameters_lm(angular_size: bool = False, viewing_angle: bool = False, distance: bool = False, tilt: bool = False, size: bool = False):
        """Render the "07 height lmplot" figure for every enabled value column.

        The questionnaire "Height" of each participant is attached to the
        trial answers as a "height" column, then one lmplot figure per enabled
        column is saved in every format of ``FiguresConfig.VisFormats``.

        Only columns whose flag is True are rendered. Previously the flags
        were ignored and all five figures were always produced.

        :param angular_size: render the figure for "angularSize"
        :param viewing_angle: render the figure for "viewingAngle"
        :param distance: render the figure for "distance"
        :param tilt: render the figure for "tilt"
        :param size: render the figure for "size"
        :raises IndexError: if a participant in the trial answers has no
            questionnaire row (same exception type as the former row-wise lookup)
        """
        q_data = DataProcessing.get_questionnaire_data()
        data = DataProcessing.get_trials_answers()

        key_to_use = {
            "angularSize": angular_size,
            "viewingAngle": viewing_angle,
            "distance": distance,
            "tilt": tilt,
            "size": size
        }
        if not any(key_to_use.values()):
            return

        # The height column does not depend on the plotted key, so build it
        # once with a vectorised lookup instead of a row-wise apply per key.
        # drop_duplicates keeps the first questionnaire row per participant,
        # matching the former ``.values[0]`` choice.
        height_by_participant = (
            q_data.drop_duplicates(subset="participant").set_index("participant")["Height"]
        )
        unknown = ~data["participant"].isin(height_by_participant.index)
        if unknown.any():
            raise IndexError(
                "No questionnaire entry for participant(s): "
                f"{data.loc[unknown, 'participant'].unique().tolist()}"
            )
        # assign() returns a new frame, so the frame handed out by
        # DataProcessing is not modified.
        plot_data = data.assign(height=data["participant"].map(height_by_participant))

        for other_key, render in key_to_use.items():
            if not render:
                continue

            vis = OtherVisualizationCreator._height_to_value_lmplot(plot_data, other_key)

            vis.fig.tight_layout()
            for format_str in FiguresConfig.VisFormats:
                vis.fig.savefig(f"{Config.FigureOutputPath}/07 height lmplot for {other_key}.{format_str}")
            plt.close(vis.fig)

    @staticmethod
    def _height_to_value_lmplot(data: pd.DataFrame, other_key: str):
        """Create an lmplot of ``other_key`` against "height", one facet per conditionKey.

        Every facet gets x limits covering the padded "height" fixed values,
        and y ticks and limits from the fixed values of ``other_key``.

        :param data: rows to plot; must contain "height", "conditionKey" and ``other_key``
        :param other_key: column on the y axis
        :return: the grid object returned by ``sns.lmplot``
        """
        with sns.axes_style(FiguresConfig.SnsStyle):
            vis = sns.lmplot(
                data=data,
                x="height",
                y=other_key,
                col="conditionKey",
            )

        height_values = DataConfig.FixedValues["height"]
        value_ticks = DataConfig.FixedValues[other_key]
        x_pad = (height_values[-1] - height_values[0]) * FiguresConfig.RangeLimitFactor
        y_pad = (value_ticks[-1] - value_ticks[0]) * FiguresConfig.RangeLimitFactor

        # Configure every facet the grid actually has rather than assuming
        # exactly four conditionKey columns.
        for ax in vis.axes.flat:
            ax.set_xlim(
                left=height_values[0] - x_pad,
                right=height_values[-1] + x_pad
            )
            ax.set_ylim(
                bottom=value_ticks[0] - y_pad,
                top=value_ticks[-1] + y_pad
            )
            # x ticks are left to matplotlib; only the y ticks are pinned.
            ax.set_yticks(value_ticks)

        return vis
    # endregion

    # region 08 Parts for Parameters
    @staticmethod
    def create_vis_for_parts_and_parameters(part1: bool = False, part2: bool = False, part3: bool = False):
        """Render the "08 part N" figure for every enabled part.

        Each figure is a facet grid of two-dimensional KDEs with one column
        per ``fixedValueIndex`` and one row per ``conditionKey``. The trial
        answers of a part are only loaded when that part is rendered.

        :param part1: render part 1 (tilt on x, distance on y)
        :param part2: render part 2 (distance on x, size on y)
        :param part3: render part 3 (size on x, tilt on y)
        """
        parts = {
            1: {"x": "tilt", "y": "distance", "render": part1},
            2: {"x": "distance", "y": "size", "render": part2},
            3: {"x": "size", "y": "tilt", "render": part3},
        }

        for part, part_spec in parts.items():
            if not part_spec["render"]:
                continue

            # Load lazily so disabled parts cost nothing.
            part_answers = DataProcessing.get_trials_answers(part=part)

            with sns.axes_style(FiguresConfig.SnsStyle):
                grid = sns.FacetGrid(
                    data=part_answers,
                    col="fixedValueIndex",
                    row="conditionKey",
                )
                grid.fig.set_size_inches(25, 20)
                grid.fig.set_dpi(FiguresConfig.Dpi)

                OtherVisualizationCreator._parts_per_parameter_kde(grid, part_spec["x"], part_spec["y"])

            # The FacetGrid constructor always returns a grid, so no None
            # check is needed before saving.
            grid.fig.tight_layout()
            for format_str in FiguresConfig.VisFormats:
                grid.fig.savefig(f"{Config.FigureOutputPath}/08 part {part}.{format_str}")
            plt.close(grid.fig)

    @staticmethod
    def _parts_per_parameter_kde(grid, x: str, y: str):
        """Map a filled two-dimensional KDE of ``x`` against ``y`` onto every facet.

        Ticks are placed on the ``DataConfig.FixedValues`` of each column and
        the limits cover their padded span.

        :param grid: facet grid to draw into
        :param x: column on the x axis
        :param y: column on the y axis
        """
        def padded_limits(values):
            # Extend the fixed-value span by RangeLimitFactor on both sides.
            pad = (values[-1] - values[0]) * FiguresConfig.RangeLimitFactor
            return (values[0] - pad, values[-1] + pad)

        grid.map(
            sns.kdeplot,
            x,
            y,
            fill=True,
            cmap=sns.cubehelix_palette(start=0, light=1, as_cmap=True),
        )

        x_ticks = DataConfig.FixedValues[x]
        y_ticks = DataConfig.FixedValues[y]
        grid.set(
            xlim=padded_limits(x_ticks),
            ylim=padded_limits(y_ticks),
            xticks=x_ticks,
            yticks=y_ticks,
        )
    # endregion

    # region 09 Parts for Angular Values
    @staticmethod
    def create_vis_for_parts_on_angular_values():
        """Render figure "09": angular-value KDEs faceted by part and conditionKey.

        The grid has one column per "part" and one row per "conditionKey".
        The figure is saved in every format of ``FiguresConfig.VisFormats``.
        """
        answers = DataProcessing.get_trials_answers()

        with sns.axes_style(FiguresConfig.SnsStyle):
            grid = sns.FacetGrid(data=answers, col="part", row="conditionKey")
            grid.fig.set_size_inches(12, 16)
            grid.fig.set_dpi(FiguresConfig.Dpi)

            OtherVisualizationCreator._parts_and_angular_value_kde(grid)

        figure = grid.fig
        figure.tight_layout()
        for format_str in FiguresConfig.VisFormats:
            figure.savefig(f"{Config.FigureOutputPath}/09 part and angular values.{format_str}")
        plt.close(figure)

    @staticmethod
    def _parts_and_angular_value_kde(grid):
        """Map a filled 2D KDE of angularSize (x) against viewingAngle (y) onto every facet.

        Ticks are placed on the ``DataConfig.FixedValues`` of both columns and
        the limits cover their padded span.

        :param grid: facet grid to draw into
        """
        columns = ("angularSize", "viewingAngle")

        grid.map(
            sns.kdeplot,
            *columns,
            fill=True,
            cmap=sns.cubehelix_palette(start=0, light=1, as_cmap=True),
        )

        x_ticks, y_ticks = (DataConfig.FixedValues[column] for column in columns)
        # Padding is RangeLimitFactor times the span of the fixed values.
        x_pad = (x_ticks[-1] - x_ticks[0]) * FiguresConfig.RangeLimitFactor
        y_pad = (y_ticks[-1] - y_ticks[0]) * FiguresConfig.RangeLimitFactor

        grid.set(
            xlim=(x_ticks[0] - x_pad, x_ticks[-1] + x_pad),
            ylim=(y_ticks[0] - y_pad, y_ticks[-1] + y_pad),
            xticks=x_ticks,
            yticks=y_ticks,
        )
    # endregion

    # region 10 Course of Input Changes
    @staticmethod
    def create_vis_input_course_per_participant(participants: Optional[List[int]], parameters: bool = False, angular_values: bool = False):
        """Plot the course of each participant's input over time, one subplot per trial.

        For every participant and every enabled key class one figure is saved to
        ``Config.FigureOutputPath`` in each format of ``FiguresConfig.VisFormats``.
        The participant's input is split into trials at every row whose
        ``timeDelta`` equals ``0.0``. Each trial with more than one row gets one
        subplot with two y-axes (one per plotted key). Subplots fill a
        12 x 10 grid row by row.

        :param participants: participants to render; ``None`` renders nothing.
        :param parameters: plot the two parameters assigned to the trial's part.
        :param angular_values: plot ``angularSize`` and ``viewingAngle``.
        """
        if participants is None:
            return

        key_to_use = {
            "parameters": parameters,
            "angularValues": angular_values,
        }
        # Pair of keys plotted per trial for the "parameters" class, by part.
        parameter_keys_per_part = {
            1: ("distance", "tilt"),
            2: ("size", "distance"),
            3: ("tilt", "size"),
        }
        angular_keys = ("angularSize", "viewingAngle")
        n_rows, n_cols = 12, 10

        for participant in participants:
            input_data = DataProcessing.get_participant_input(participant)

            # Positional offsets of the rows that start a trial. The slices below
            # use ``iloc``, which is positional, so offsets are used rather than
            # index labels. Pairing each start with the next one (and ``None``
            # for the last) lets the final trial run to the end of the frame
            # without assuming a fixed number of trials.
            trial_starts = np.flatnonzero((input_data["timeDelta"] == 0.0).to_numpy()).tolist()
            trial_bounds = list(zip(trial_starts, trial_starts[1:] + [None]))

            for key_class, render in key_to_use.items():
                if not render:
                    continue

                print(f"Working on: {participant} {key_class}")

                fig_size = (36, 30)
                margin = FiguresConfig.calc_margin_complete(fig_size=fig_size, absolute_margin=0.7)
                spacing = FiguresConfig.calc_spacing_complete(fig_size=fig_size, absolute_spacing=0.6)

                fig = plt.figure(
                    figsize=fig_size,
                    dpi=FiguresConfig.Dpi
                )
                gs = GridSpec(
                    nrows=n_rows,
                    ncols=n_cols,
                    figure=fig,
                    left=margin["left"],
                    top=margin["top"],
                    right=margin["right"],
                    bottom=margin["bottom"],
                    wspace=fig_size[0] * spacing["vertical"],
                    hspace=fig_size[1] * spacing["horizontal"],
                )

                for trial_number, (start, stop) in enumerate(trial_bounds):
                    data = input_data.iloc[start:stop, :]

                    # A single row cannot be drawn as a line.
                    if data.shape[0] <= 1:
                        continue

                    if key_class == "parameters":
                        part = data["part"].values[0]
                        if part not in parameter_keys_per_part:
                            # Never fall back to the keys of a previous trial.
                            print(f"Skipping trial {trial_number + 1} of participant {participant}: unknown part {part}")
                            continue
                        y_key1, y_key2 = parameter_keys_per_part[part]
                    else:
                        y_key1, y_key2 = angular_keys

                    row, col = divmod(trial_number, n_cols)
                    ax1 = fig.add_subplot(gs[row, col])
                    ax2 = ax1.twinx()
                    OtherVisualizationCreator._participant_input_linechart(ax1, data, y_key1)
                    OtherVisualizationCreator._participant_input_linechart(ax2, data, y_key2)

                fig.tight_layout()
                for format_str in FiguresConfig.VisFormats:
                    fig.savefig(f"{Config.FigureOutputPath}/10 participant input for {key_class} - {participant}.{format_str}")
                plt.close(fig)

    @staticmethod
    def _participant_input_linechart(ax: plt.Axes, data: pd.DataFrame, y_key: str):
        """Draw ``y_key`` over ``timeDelta`` as a line on ``ax``.

        The y-axis spans the first to the last entry of
        ``DataConfig.FixedValues[y_key]``, widened on both sides by
        ``FiguresConfig.RangeLimitFactor`` times that span, with the fixed
        values as ticks. A legend on the axes, if any, is removed.

        :param ax: axes to draw on.
        :param data: rows to plot; read through the ``timeDelta`` and ``y_key`` columns.
        :param y_key: column to plot; also selects the line color and the fixed values.
        """
        with sns.axes_style(FiguresConfig.SnsStyle):
            sns.lineplot(
                ax=ax,
                data=data,
                x="timeDelta",
                y=y_key,
                color=FiguresConfig.Palettes["parameter"][y_key],
            )

        fixed_values = DataConfig.FixedValues[y_key]
        lowest, highest = fixed_values[0], fixed_values[-1]
        value_range = highest - lowest
        ax.set_ylim(
            bottom=lowest - value_range * FiguresConfig.RangeLimitFactor,
            top=highest + value_range * FiguresConfig.RangeLimitFactor,
        )
        ax.set_yticks(fixed_values)

        legend = ax.get_legend()
        if legend is not None:
            legend.remove()
    # endregion

    # region 11 parameter over Fixed Values
    @staticmethod
    def create_vis_parameters_over_fixed_values(part1: bool = False, part2: bool = False, part3: bool = False):
        """Plot, per selected part, four keys over the fixed value index.

        Each selected part yields one figure with four stacked line charts
        (two parameters of that part plus ``angularSize`` and ``viewingAngle``).
        The figure is saved to ``Config.FigureOutputPath`` once per entry in
        ``FiguresConfig.VisFormats`` and closed afterwards.

        :param part1: render the figure for part 1.
        :param part2: render the figure for part 2.
        :param part3: render the figure for part 3.
        """
        angular_keys = ["angularSize", "viewingAngle"]
        keys_per_part = {
            1: ["tilt", "distance"] + angular_keys,
            2: ["size", "distance"] + angular_keys,
            3: ["tilt", "size"] + angular_keys,
        }
        render_per_part = {1: part1, 2: part2, 3: part3}

        # The answers of all three parts are loaded up front, selected or not.
        data_per_parts = {
            part: {
                "data": DataProcessing.get_trials_answers(part=part),
                "keys": keys,
                "render": render_per_part[part],
            }
            for part, keys in keys_per_part.items()
        }

        fig_size = (12, 12)
        for part, entry in data_per_parts.items():
            if not entry["render"]:
                continue

            margin = FiguresConfig.calc_margin_complete(fig_size=fig_size, absolute_margin=0.7)
            spacing = FiguresConfig.calc_spacing_complete(fig_size=fig_size, absolute_spacing=0.35)

            fig = plt.figure(figsize=fig_size, dpi=FiguresConfig.Dpi)
            gs = GridSpec(
                nrows=4,
                ncols=1,
                figure=fig,
                left=margin["left"],
                top=margin["top"],
                right=margin["right"],
                bottom=margin["bottom"],
                wspace=fig_size[0] * spacing["vertical"],
                hspace=fig_size[1] * spacing["horizontal"],
            )

            part_answers = entry["data"]
            for row, key in enumerate(entry["keys"]):
                ax = fig.add_subplot(gs[row, 0])
                OtherVisualizationCreator._part_over_fixed_values_linechart(ax, part_answers, key)

            # The legend is built from the last (bottom) chart.
            Utils.set_legend_from_vis(fig, ax, True)

            fig.tight_layout()
            for extension in FiguresConfig.VisFormats:
                fig.savefig(f"{Config.FigureOutputPath}/11 fixed values part {part}.{extension}")
            plt.close(fig)

    @staticmethod
    def _part_over_fixed_values_linechart(ax: plt.Axes, data: pd.DataFrame, key: str):
        """Draw ``key`` over ``fixedValueIndex`` as one line per condition key.

        Both axes span the first to the last entry of their
        ``DataConfig.FixedValues`` list, widened on both sides by
        ``FiguresConfig.RangeLimitFactor`` times that span, with the fixed
        values as ticks. A legend on the axes, if any, is removed.

        :param ax: axes to draw on.
        :param data: rows to plot; read through ``fixedValueIndex``, ``key`` and ``conditionKey``.
        :param key: column shown on the y-axis; also selects the y fixed values.
        """
        condition_palette = FiguresConfig.Palettes["conditionKey"]
        with sns.axes_style(FiguresConfig.SnsStyle):
            sns.lineplot(
                ax=ax,
                data=data,
                x="fixedValueIndex",
                y=key,
                hue="conditionKey",
                hue_order=condition_palette.keys(),
                palette=condition_palette,
                markers=True,
                err_style=None,
            )

        index_values = DataConfig.FixedValues["fixedValueIndex"]
        key_values = DataConfig.FixedValues[key]
        index_range = index_values[-1] - index_values[0]
        key_range = key_values[-1] - key_values[0]

        ax.set_xlim(
            left=index_values[0] - index_range * FiguresConfig.RangeLimitFactor,
            right=index_values[-1] + index_range * FiguresConfig.RangeLimitFactor,
        )
        ax.set_ylim(
            bottom=key_values[0] - key_range * FiguresConfig.RangeLimitFactor,
            top=key_values[-1] + key_range * FiguresConfig.RangeLimitFactor,
        )
        ax.set_xticks(index_values)
        ax.set_yticks(key_values)

        legend = ax.get_legend()
        if legend is not None:
            legend.remove()
    # endregion


if __name__ == "__main__":
    # Render every visualization with all of its variants switched on,
    # in the order the calls are listed here.
    OtherVisualizationCreator.create_vis_for_trial_answers(sp=True, hp=True, kde=True)

    OtherVisualizationCreator.create_vis_for_other_params(
        kde=True, angular_size=True, viewing_angle=True, duration=True
    )

    OtherVisualizationCreator.create_vis_angular_size_to_viewing_angle()

    OtherVisualizationCreator.create_vis_for_angular_size_to_parameters(
        kde=True, angular_size=True, viewing_angle=True
    )

    OtherVisualizationCreator.create_vis_for_other_params_to_participants(
        bar=True, angular_size=True, viewing_angle=True, duration=True, aggregation="mean"
    )

    OtherVisualizationCreator.create_vis_for_height_to_parameters(
        sp=True, swarm=True, strip=True, vp=True
    )

    OtherVisualizationCreator.create_vis_for_height_to_parameters_lm(
        angular_size=True, viewing_angle=True, distance=True, tilt=True, size=True
    )

    OtherVisualizationCreator.create_vis_for_parts_and_parameters(part1=True, part2=True, part3=True)

    OtherVisualizationCreator.create_vis_for_parts_on_angular_values()

    # Participants 1 through 26.
    OtherVisualizationCreator.create_vis_input_course_per_participant(
        participants=list(range(1, 27)), parameters=True, angular_values=True
    )

    OtherVisualizationCreator.create_vis_parameters_over_fixed_values(part1=True, part2=True, part3=True)
